import difflib
import inspect
import os
import re
import time
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
from typing import Self

import torch
from sensai.util import logging
from sensai.util.git import GitStatus, git_status
from sensai.util.pickle import dump_pickle, load_pickle


def format_log_message(
    logger: logging.Logger,
    level: int,
    msg: str,
    formatter: logging.Formatter,
    stacklevel: int = 1,
) -> str:
    """
    Formats a log message as it would have been created by `logger.log(level, msg)` with the given formatter.

    The source location of the record (path name, line number and function name) is taken from the frame
    `stacklevel` levels above this function's own frame, i.e. `stacklevel=1` refers to the direct caller.

    :param logger: the logger
    :param level: the log level
    :param msg: the message
    :param formatter: the formatter
    :param stacklevel: the stack level of the function to report as the generator
    :return: the formatted log message (not including trailing newline)
    :raises IndexError: if the call stack is not deep enough for the given `stacklevel`
    """
    frame = inspect.currentframe()
    try:
        # Walk up the stack manually: unlike inspect.stack(), this does not resolve source files and
        # source lines for every frame on the stack, which is expensive for frequently emitted messages.
        for _ in range(stacklevel):
            if frame is None:
                break
            frame = frame.f_back
        if frame is None:
            raise IndexError(f"stacklevel {stacklevel} exceeds the depth of the call stack")
        code = frame.f_code
        pathname, lineno, func = code.co_filename, frame.f_lineno, code.co_name
    finally:
        # drop the frame reference explicitly so that no frames are kept alive via reference cycles
        del frame

    # makeRecord sets the record's creation time (consistently with `msecs` and `relativeCreated`)
    record = logger.makeRecord(
        name=logger.name,
        level=level,
        fn=pathname,
        lno=lineno,
        msg=msg,
        args=(),
        exc_info=None,
        func=func,
        extra=None,
    )
    # pre-populate `asctime` for formatters that access it directly; a formatter whose format string
    # uses `%(asctime)s` recomputes it during `format`
    record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created))

    return formatter.format(record)


class TraceLogger:
    """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests."""

    is_enabled = False
    """
    whether the trace logger is enabled.

    NOTE: The preferred way to enable this is via the context manager.
    """
    verbose = False
    """
    whether to print trace log messages to stdout.
    """
    MESSAGE_TAG = "[TRACE]"
    """
    a tag which is added at the beginning of log messages generated by this logger
    """
    LOG_LEVEL = logging.DEBUG
    log_buffer: StringIO | None = None
    log_formatter: logging.Formatter | None = None

    @classmethod
    def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None:
        """
        Logs a message intended for tracing agent-env interaction, which is enabled via
        `TraceAgentEnvLoggerContext`.

        :param logger: the logger to use for the actual logging
        :param message_generator: function which generates the log message (which may be expensive);
            if logging is disabled, the function will not be called.
        """
        if not cls.is_enabled:
            return

        msg = message_generator()
        msg = cls.MESSAGE_TAG + " " + msg

        # Log with caller's frame info
        logger.log(logging.DEBUG, msg, stacklevel=2)

        # If a dedicated memory buffer is configured, also store the message there
        if cls.log_buffer is not None:
            msg_formatted = format_log_message(
                logger,
                logging.DEBUG,
                msg,
                cls.log_formatter,
                stacklevel=2,
            )
            cls.log_buffer.write(msg_formatted + "\n")
            if cls.verbose:
                print(msg_formatted)


@dataclass
class TraceLog:
    """A trace log, i.e. a sequence of trace log lines."""

    log_lines: list[str]

    def save_log(self, path: str) -> None:
        """
        Saves the log to a UTF-8 text file, terminating each line with a newline.

        :param path: the path of the file to (over)write
        """
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in self.log_lines)

    def print_log(self) -> None:
        """Prints each log line to stdout."""
        for line in self.log_lines:
            print(line)

    def get_full_log(self) -> str:
        """:return: the log lines joined with newlines (no trailing newline)"""
        return "\n".join(self.log_lines)

    def reduce_log_to_messages(self) -> "TraceLog":
        """
        Removes logger names and function names from the log entries, such that each log message
        contains only the main text message itself (starting with the content after the logger's tag).
        Lines not containing the tag are retained unchanged.

        :return: the result with reduced log messages
        """
        # non-greedy and applied once: strip everything up to and including the *first* tag only,
        # such that occurrences of the tag within the message itself are retained
        prefix_pattern = re.compile(".*?" + re.escape(TraceLogger.MESSAGE_TAG))
        return TraceLog([prefix_pattern.sub("", line, count=1) for line in self.log_lines])

    def filter_messages(
        self,
        required_messages: Sequence[str] = (),
        optional_messages: Sequence[str] = (),
        ignored_messages: Sequence[str] = (),
    ) -> "TraceLog":
        """
        Applies inclusion and or exclusion filtering to the log messages.
        If either `required_messages` or `optional_messages` is non-empty, inclusion filtering is applied,
        i.e. only lines containing at least one of these message fragments are retained.
        If `ignored_messages` is non-empty, exclusion filtering is applied.
        If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence.

        :param required_messages: required message substrings to filter for; each message is required to appear at least once
            (triggering exception otherwise)
        :param optional_messages: additional messages fragments to filter for; these are not required
        :param ignored_messages: message fragments that result in exclusion; takes precedence over
            `required_messages` and `optional_messages`
        :return: the result with reduced log messages
        :raises AssertionError: if a required message does not appear in any retained line
        """
        required_message_counts = [0] * len(required_messages)
        apply_inclusion_filter = bool(required_messages) or bool(optional_messages)

        def retain_line(line: str) -> bool:
            if any(ignored_message in line for ignored_message in ignored_messages):
                return False
            if not apply_inclusion_filter:
                return True
            # count every required message contained in the line (not just the first one), such that
            # overlapping required messages (e.g. "a" and "ab") are all registered
            is_required = False
            for i, required_message in enumerate(required_messages):
                if required_message in line:
                    required_message_counts[i] += 1
                    is_required = True
            return is_required or any(optional_message in line for optional_message in optional_messages)

        lines = [line for line in self.log_lines if retain_line(line)]

        missing_messages = [
            message for message, count in zip(required_messages, required_message_counts) if count == 0
        ]
        if missing_messages:
            # explicit raise (rather than `assert`) such that the check is not stripped under `python -O`
            raise AssertionError(
                "Not all types of required messages were found in the trace. Were log messages changed? "
                f"Missing: {missing_messages}",
            )

        return TraceLog(lines)


class TraceLoggerContext:
    """
    A context manager which enables the trace logger.
    Apart from enabling the logging, it can optionally create a memory log buffer, such that
    getting the trace log is not strictly dependent on the logging system.
    Upon exit, the trace logger's previous configuration is restored, such that contexts can be nested.
    """

    def __init__(
        self,
        enable_log_buffer: bool = True,
        log_format: str = "%(name)s:%(funcName)s - %(message)s",
    ) -> None:
        """
        :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents
            can, within the context of this manager, be accessed via method `get_log`.
        :param log_format: the logger format string to use for the dedicated log buffer
        """
        self._enable_log_buffer = enable_log_buffer
        self._log_format: str = log_format
        self._log_buffer: StringIO | None = None
        # the trace logger's (is_enabled, log_buffer, log_formatter) before entering this context
        self._previous_state: tuple[bool, StringIO | None, logging.Formatter | None] | None = None

    def __enter__(self) -> Self:
        self._previous_state = (TraceLogger.is_enabled, TraceLogger.log_buffer, TraceLogger.log_formatter)
        TraceLogger.is_enabled = True

        if self._enable_log_buffer:
            TraceLogger.log_buffer = StringIO()
            TraceLogger.log_formatter = logging.Formatter(self._log_format)
            self._log_buffer = TraceLogger.log_buffer

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        if self._previous_state is None:
            # exited without having been entered: fall back to the disabled default configuration
            is_enabled, log_buffer, log_formatter = False, None, None
        else:
            is_enabled, log_buffer, log_formatter = self._previous_state
            self._previous_state = None
        TraceLogger.is_enabled = is_enabled
        TraceLogger.log_buffer = log_buffer
        TraceLogger.log_formatter = log_formatter

    def get_log(self) -> TraceLog:
        """
        :return: the full trace log that was captured if `enable_log_buffer` was enabled at construction.
            Since every buffered message is newline-terminated, the last line is empty; this is retained
            for compatibility with previously stored reference logs.
        :raises RuntimeError: if the log buffer was not enabled at construction (or the context was not entered)
        """
        if self._log_buffer is None:
            raise RuntimeError(
                "This method is only supported if the log buffer is enabled at construction",
            )
        return TraceLog(log_lines=self._log_buffer.getvalue().split("\n"))


def torch_param_hash(module: torch.nn.Module) -> str:
    """
    Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored.

    The hash covers the parameters' raw data (in `module.parameters()` order, each in row-major order), but
    not their shapes or dtypes. Parameters of dtypes not supported by numpy (e.g. bfloat16) are hashed via
    their raw bytes.

    :param module: a torch module
    :return: a hex digest of the parameters of the module
    """
    import hashlib

    def param_bytes(param: torch.Tensor) -> bytes:
        tensor = param.detach().cpu()
        try:
            return tensor.numpy().tobytes()
        except TypeError:
            # numpy has no equivalent for this dtype (e.g. bfloat16): reinterpret the data as raw bytes
            return tensor.contiguous().reshape(-1).view(torch.uint8).numpy().tobytes()

    hasher = hashlib.sha1()
    for param in module.parameters():
        if param.requires_grad:
            hasher.update(param_bytes(param))
    return hasher.hexdigest()


class TraceDeterminismTest:
    """
    Checks trace logs for (non-)determinism by comparing them against reference results stored on disk.

    One reference result (the trace log together with the git status at the time of its creation) is stored
    per test name in `base_path`. Logs are compared after reducing them to their messages and removing
    ignored messages; optionally, a comparison restricted to core messages can be used to tolerate
    changes in non-core messages.
    """

    def __init__(
        self,
        base_path: Path,
        core_messages: Sequence[str] = (),
        ignored_messages: Sequence[str] = (),
        log_filename: str | None = None,
    ) -> None:
        """
        :param base_path: the directory where the reference results are stored (will be created if necessary)
        :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core
        :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over
            `core_messages`
        :param log_filename: the name of the log file to which results are to be written (if any)
        """
        base_path.mkdir(parents=True, exist_ok=True)
        self.base_path = base_path
        self.core_messages = core_messages
        self.ignored_messages = ignored_messages
        self.log_filename = log_filename

    @dataclass(kw_only=True)
    class Result:
        # NOTE: instances are pickled as reference results, so renaming this class or its fields
        # would break the loading of existing reference files
        git_status: GitStatus
        log: TraceLog

    def check(
        self,
        current_log: TraceLog,
        name: str,
        create_reference_result: bool = False,
        pass_if_core_messages_unchanged: bool = False,
    ) -> None:
        """
        Checks the given log against the reference result for the given name, failing the current
        pytest test (via `pytest.fail`) if the check does not pass.

        :param current_log: the result to check
        :param name: the name of the reference result; must be unique among all tests!
        :param create_reference_result: whether to update the reference result with the given result
            (the check is then performed against the newly stored reference)
        :param pass_if_core_messages_unchanged: whether to let the check pass if the logs differ but the
            core messages (see `core_messages` at construction) are unchanged; has no effect if no core
            messages were specified
        """
        import pytest

        reference_result_path = self.base_path / f"{name}.pkl.bz2"
        current_git_status = git_status()

        if create_reference_result:
            current_result = self.Result(git_status=current_git_status, log=current_log)
            dump_pickle(current_result, reference_result_path)
        elif not reference_result_path.exists():
            pytest.fail(
                f"No reference result found for '{name}' at {reference_result_path}. "
                "Enable the `create_reference_result` flag temporarily, rerun the test and then "
                "commit the created reference file.",
            )

        reference_result: TraceDeterminismTest.Result = load_pickle(
            reference_result_path,
        )
        reference_log = reference_result.log

        current_log_reduced = current_log.reduce_log_to_messages().filter_messages(
            ignored_messages=self.ignored_messages,
        )
        reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages(
            ignored_messages=self.ignored_messages,
        )

        # logs which are saved to files for inspection if the check fails
        results: list[tuple[TraceLog, str]] = [
            (reference_log_reduced, "expected"),
            (current_log_reduced, "current"),
            (reference_log, "expected_full"),
            (current_log, "current_full"),
        ]

        if self.core_messages:
            result_main_messages = current_log_reduced.filter_messages(
                required_messages=self.core_messages,
            )
            reference_result_main_messages = reference_log_reduced.filter_messages(
                required_messages=self.core_messages,
            )
            results.extend(
                [
                    (reference_result_main_messages, "expected_core"),
                    (result_main_messages, "current_core"),
                ],
            )
        else:
            result_main_messages = current_log_reduced
            reference_result_main_messages = reference_log_reduced

        if current_log_reduced.get_full_log() == reference_log_reduced.get_full_log():
            status_passed = True
            status_message = "OK"
        else:
            core_messages_unchanged = (
                len(self.core_messages) > 0
                and result_main_messages.get_full_log()
                == reference_result_main_messages.get_full_log()
            )
            status_passed = core_messages_unchanged and pass_if_core_messages_unchanged

            if status_passed:
                status_message = "OK (core messages unchanged)"
            else:
                main_message = self._create_failure_details(
                    name,
                    results,
                    reference_log_reduced,
                    current_log_reduced,
                )
                if not self.core_messages:
                    status_message = f"The behaviour log has changed. {main_message}"
                elif core_messages_unchanged:
                    status_message = (
                        "The behaviour log has changed, but the core messages are still the same (so this "
                        f"probably isn't an issue). {main_message}"
                    )
                else:
                    status_message = f"The behaviour log has changed; even the core messages are different. {main_message}"

        self._append_to_log_file(name, reference_result.git_status, current_git_status, status_message)

        if not status_passed:
            pytest.fail(status_message)

    @staticmethod
    def _create_failure_details(
        name: str,
        logs: Sequence[tuple[TraceLog, str]],
        reference_log: TraceLog,
        current_log: TraceLog,
        num_diff_lines_to_show: int = 30,
    ) -> str:
        """
        Saves the given logs to files in the current working directory and creates a message which
        references these files and contains the first part of the diff between the two given logs.

        :param name: the name of the reference result (used in the file names)
        :param logs: pairs of logs and file name suffixes
        :param reference_log: the (reduced) reference log to diff against
        :param current_log: the (reduced) current log
        :param num_diff_lines_to_show: the maximum number of diff lines to include in the message
        :return: the message
        """
        import itertools

        paths = []
        for log, suffix in logs:
            path = os.path.abspath(f"determinism_{name}_{suffix}.txt")
            log.save_log(path)
            paths.append(path)

        paths_str = "\n".join(paths)
        message = (
            f"Please inspect the changes by diffing the log files:\n{paths_str}\n"
            f"If the changes are OK, enable the `create_reference_result` flag temporarily, "
            "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n"
        )
        diff_lines = difflib.unified_diff(
            reference_log.log_lines,
            current_log.log_lines,
            fromfile="expected.txt",
            tofile="current.txt",
            lineterm="",
        )
        message += "".join(line + "\n" for line in itertools.islice(diff_lines, num_diff_lines_to_show))
        return message

    def _append_to_log_file(
        self,
        name: str,
        reference_git_status: GitStatus,
        current_git_status: GitStatus,
        status_message: str,
    ) -> None:
        """Appends a report on the check of the reference result `name` to the log file (if one is configured)."""
        if not self.log_filename:
            return
        hr = "-" * 100
        with open(self.log_filename, "a", encoding="utf-8") as f:
            f.write(f"\n\n{hr}\nName: {name}\n")
            f.write(f"Reference state: {reference_git_status}\n")
            f.write(f"Current state: {current_git_status}\n")
            f.write(f"Test result: {status_message}\n")
